"""Entry point."""

import argparse
import random
import time
import torch
import numpy as np
import os

from exps.controller import Controller, ControllerTrainer
from exps.nasbenchs.nasbench101 import Nasbench101
from exps.nasbenchs.nasbench201 import Nasbench201
from exps.nasbenchs.natsbench import Natsbench
from exps.predictors.acquisition_functions import acquisition_function
from exps.predictors.train_predictor import PredictorModel
from exps.predictors.transformer import TransformerPredictor
from exps.utils import Logger, get_logger, encoded_arch

from exps.predictors.ensemble import Ensemble


def build_args():
    """Create the 'PNASM' argument parser and parse the command line.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()`` after all
        options from ``register_default_args`` have been registered.
    """
    cli = argparse.ArgumentParser(description='PNASM')
    register_default_args(cli)
    return cli.parse_args()


def register_default_args(parser):
    """Register every command-line option of the search script on ``parser``.

    Args:
        parser: an ``argparse.ArgumentParser`` (or compatible object exposing
            ``add_argument``) that receives the options in place.

    Notes:
        * ``--dataset`` has no default, so it is ``None`` when omitted.
        * ``--is_predictor`` is a *string* flag; ``main`` compares it
          literally against ``"True"``.
    """
    # ------------------------------- Search space and dataset ---------------------------------
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["cifar10", "cifar100", "ImageNet16-120"],
        help="Choose between Cifar10/100 and ImageNet-16.",
    )
    parser.add_argument("--bench_name", type=str, default="natsbench")
    parser.add_argument("--search_space", type=str, default="tss")

    # ------------------------------- Training settings ---------------------------------
    # A negative seed makes the entry script draw random seeds and repeat the run.
    parser.add_argument("--rand_seed", type=int, help="manual seed", default=-1)
    parser.add_argument("--loops_if_rand", type=int, default=20)
    parser.add_argument("--time_budget", type=float, default=40000, help="CIFAR10-40000, CIFAR100-75000, ImageNet16-120-200000")
    parser.add_argument("--workers", type=int, default=2, help="number of data loading workers (default: 2)")
    parser.add_argument("--num_batch_per_epoch", type=int, default=5)
    parser.add_argument("--fixedk", type=int, default=2, help="Recommended value: cifar-10_2, cifar-100_5, ImageNet16-120_2")

    # ------------------------------- Search strategy and agent ---------------------------------
    parser.add_argument("--episodes", type=int, default=20, help="Sample the number of archs per controller's update")
    parser.add_argument("--update_policy_epochs", type=int, default=5)
    parser.add_argument("--update_agent_algo", type=str, default="ppo")
    parser.add_argument("--clip_ratio", type=float, default=0.2)

    parser.add_argument("--controller_epochs", type=int, default=40)
    parser.add_argument("--controller_learning_rate", type=float, default=0.001, help="learning rate for arch encoding")
    parser.add_argument("--controller_weight_decay", type=float, default=1e-3, help="weight decay for arch encoding")
    # This value is forwarded as ``eps`` to the Adam optimizer in ``main``; the
    # help text used to be a copy of the weight-decay description.
    parser.add_argument("--controller_eps", type=float, default=1e-8, help="Adam eps (numerical stability term) for arch encoding")

    # ------------------------------- Predictor ---------------------------------
    parser.add_argument("--is_predictor", type=str, default="True")
    parser.add_argument("--predictor_type", type=str, default="seminas")
    parser.add_argument("--init_epoch_for_predictor", type=int, default=1, help="Sample 100 archs per epoch.")
    parser.add_argument("--predictor_mode", type=str, default="fixed_k", help="including: all, fixed_k and adaptive.")

    parser.add_argument("--is_ensemble", type=int, default=1)
    parser.add_argument("--num_ensemble", type=int, default=3)

    parser.add_argument("--theta", type=float, default=0.1)

    # ------------------------------- log and path ---------------------------------
    parser.add_argument("--log_path", type=str, default="./outputs/")
    parser.add_argument("--nasbench_data_path", type=str, default="./exps/nasbenchs/bench_dataset")


def prepare_seed(rand_seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and CUDA) generators.

    Args:
        rand_seed: the seed value handed unchanged to every seeding function.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same seeding calls, in the same order, as listed above.
    for apply_seed in seeders:
        apply_seed(rand_seed)


def train_predictor(args, controller_trainer, nas_bench, random_arch):
    """Fit the trainer's predictor on the collected architectures.

    The training data is ``controller_trainer.buffer_unique_archs["true_info"]``.
    After fitting, ``controller_trainer.acq_fn`` is set and the wall-clock time
    spent in this function is added to
    ``controller_trainer.cur_total_time_costs``.

    Args:
        args: parsed options; reads ``is_ensemble``, ``predictor_type`` and
            ``dataset``.
        controller_trainer: object holding the predictor, the architecture
            buffer and the running time counter.
        nas_bench: benchmark handle, only forwarded to ``encoded_arch`` in the
            non-ensemble path.
        random_arch: forwarded to
            ``controller_trainer.get_unique_arch_by_random`` when unlabeled
            samples are drawn for the "seminas" predictor.
    """
    began = time.time()

    if args.is_ensemble:
        # Split the buffered entries into inputs (index 0) and targets (index 1).
        xtrain, ytrain = [], []
        for entry in controller_trainer.buffer_unique_archs["true_info"]:
            xtrain.append(entry[0])
        for entry in controller_trainer.buffer_unique_archs["true_info"]:
            ytrain.append(entry[1])

        if args.predictor_type == "seminas":
            # Draw as many unlabeled architecture strings as there are labeled
            # samples and hand them to the predictor before fitting.
            unlabeled = []
            while len(unlabeled) < len(xtrain):
                sampled = controller_trainer.get_unique_arch_by_random(random_arch)
                unlabeled.append(sampled.tostr())
            controller_trainer.predictor.set_pre_computations(unlabeled=unlabeled)

        controller_trainer.predictor.fit(xtrain, ytrain)

        # The ensemble is wrapped in an "exploit_only" acquisition function.
        chosen_acq_fn = acquisition_function(
            ensemble=controller_trainer.predictor,
            ytrain=None,
            acq_fn_type="exploit_only",
        )
    else:
        encoded = encoded_arch(
            nas_bench,
            controller_trainer.buffer_unique_archs["true_info"],
            True,
            args.dataset,
        )
        controller_trainer.predictor.fit(encoded, epochs=200)

        # Without an ensemble the predictor itself serves as acquisition function.
        chosen_acq_fn = controller_trainer.predictor

    controller_trainer.acq_fn = chosen_acq_fn
    controller_trainer.cur_total_time_costs += time.time() - began


def _build_predictor(args, nas_bench):
    """Return the predictor selected by ``args`` or ``None`` when disabled."""
    if args.is_predictor != "True":
        return None

    if args.is_ensemble:
        return Ensemble(
            num_ensemble=args.num_ensemble,
            ss_type=args.bench_name,
            predictor_type=args.predictor_type,
            dataset=args.dataset,
            bench_api=nas_bench,
        )

    if args.predictor_type != 'transformer':
        raise Exception("Invalid value", args.predictor_type)

    if args.bench_name == "nasbench101":
        n_src_vocab = 5
        pos_enc_dim = 7
    elif args.bench_name == "nasbench201" or args.bench_name == "natsbench":
        n_src_vocab = 7
        pos_enc_dim = 8
    else:
        raise Exception("Invalid bench name: ", args.bench_name)

    return TransformerPredictor(
        adj_type="lapla",
        n_src_vocab=n_src_vocab,
        pos_enc_dim=pos_enc_dim,
        bench=args.bench_name
    )


def main(args, base_path):
    """Run one architecture search and return the best architecture found.

    Args:
        args: parsed command-line options. ``args.bench_name`` is normalised
            to ``"nasbench201"`` in place; ``args.is_predictor`` and
            ``args.predictor_mode`` are temporarily overwritten during the
            initial data collection and restored afterwards.
        base_path: directory prefix; the logger is created at
            ``base_path + '/' + str(args.rand_seed)``.

    Returns:
        A tuple ``(best_arch, val_acc, test_acc, total_time_cost)`` taken from
        the first entry of ``controller_trainer.get_cur_best_arch()`` and the
        trainer's accumulated time counter.

    Raises:
        NotImplementedError: if ``args.bench_name`` is ``"nasbench101"``.
        Exception: if ``args.bench_name`` or ``args.predictor_type`` is invalid.
    """
    torch.set_num_threads(args.workers)
    prepare_seed(args.rand_seed)

    output_path = base_path + '/' + str(args.rand_seed)
    log = get_logger(output_path)
    log.info(args)

    # Currently, only nasbench201 (NATS-Bench) is supported.
    if args.bench_name == "nasbench101":
        # This branch used to fall through without creating a benchmark and
        # crashed later with an unbound ``nas_bench``; fail explicitly instead.
        raise NotImplementedError(
            "nasbench101 is not supported yet; use 'nasbench201' or 'natsbench'."
        )
    elif args.bench_name == "nasbench201" or args.bench_name == "natsbench":
        nas_bench = Natsbench(args.nasbench_data_path, args.search_space)
        random_arch = nas_bench.random_topology_func(nas_bench.op_names)
        args.bench_name = "nasbench201"
    else:
        raise Exception("Invalid bench name:", args.bench_name)

    print(20 * "-" + "Setting parameters of controller." + 20 * '-')
    controller = Controller(nas_bench.edge2index, nas_bench.op_names, nas_bench.max_nodes)
    optimizer = torch.optim.Adam(
        controller.parameters(),
        lr=args.controller_learning_rate,
        betas=(0.5, 0.999),
        weight_decay=args.controller_weight_decay,
        eps=args.controller_eps,
    )
    if torch.cuda.is_available():
        controller = controller.cuda()

    print(20 * "-" + "Setting parameters of predictor." + 20 * "-")
    predictor = _build_predictor(args, nas_bench)

    controller_trainer = ControllerTrainer(controller, optimizer, predictor, nas_bench, args, log)

    log.info(20 * "-" + "Initializing predictor" + 20 * "-")
    epoch = 0
    log.info(20 * "-" + "Collect data" + 20 * "-")
    if args.is_predictor == "True":
        # The first controller round runs with the predictor switched off via
        # ``args``. Restore the flags even if that round raises, so the shared
        # ``args`` object is never left in the temporary state.
        saved_mode = args.predictor_mode
        args.predictor_mode = "None"
        args.is_predictor = "False"
        try:
            controller_trainer.pre_train_controller()
        finally:
            args.predictor_mode = saved_mode
            args.is_predictor = "True"
        epoch += 1

    log.info(20 * "-" + "Train the predictor with collected data." + 20 * "-")
    if args.is_predictor == "True":
        # we refer to the predictor seminas.
        train_predictor(args, controller_trainer, nas_bench, random_arch)
    log.info(20 * "-" + "Finish Initializing predictor" + 20 * "-")

    log.info(20 * "-" + "Searching" + 20 * "-")
    # Using the predictor has a low time cost, so the epoch limit guards
    # against a (near) infinite loop when the time budget is never reached.
    while controller_trainer.cur_total_time_costs < args.time_budget and epoch < args.controller_epochs:
        start_time = time.time()
        epoch += 1

        if args.predictor_mode == "adaptive":
            controller_trainer.pre_train_controller_by_kl()
        else:
            controller_trainer.pre_train_controller()

        if controller_trainer.cur_total_time_costs > args.time_budget:
            break

        if args.predictor_mode in ["fixed_k", "adaptive"]:
            log.info(20 * "-" + "Retrain the predictor" + 20 * "-")
            train_predictor(args, controller_trainer, nas_bench, random_arch)
            log.info(20 * '-' + "Finishing training!" + 20 * '-')
            controller_trainer.buffer_unique_archs["pred_info"] = []

        # NOTE(review): train_predictor already adds its own elapsed time to
        # cur_total_time_costs, so retraining time appears to be counted twice
        # here -- confirm whether that is intended before changing it.
        controller_trainer.cur_total_time_costs += time.time() - start_time

    best_arch_info = controller_trainer.get_cur_best_arch()
    best_entry = list(best_arch_info.values())[0]
    best_arch, val_acc, test_acc = best_entry[0], best_entry[1], best_entry[2]

    return best_arch, val_acc, test_acc, controller_trainer.cur_total_time_costs


if __name__ == "__main__":
    args = build_args()

    # --dataset has no default; without it the path concatenation below would
    # fail with an opaque TypeError, so stop early with a clear message.
    if args.dataset is None:
        raise SystemExit("error: --dataset is required (cifar10, cifar100 or ImageNet16-120)")

    # Results directory: <log_path>/<bench>/<dataset>_<algo>_<is_predictor>_<mode>_<budget>/
    base_path = args.log_path + '/' + \
                args.bench_name + '/' + \
                args.dataset + '_' + \
                args.update_agent_algo + '_' + \
                args.is_predictor + '_' + \
                args.predictor_mode + '_' + \
                str(args.time_budget) + '/'

    os.makedirs(base_path, exist_ok=True)
    path_res = base_path + 'results.txt'

    # One line per finished run, shared by both branches below.
    result_line = "Rand_seed: {}, Num_iter: {}, val_acc: {}, test_acc: {}, total_time_cost: {}, best_arch_str: {}\n"

    if args.rand_seed < 0:
        # No seed given: repeat the search with freshly drawn seeds and report
        # mean/std over all runs at the end.
        total_time_costs, val_accs, test_accs = [], [], []

        for i in range(args.loops_if_rand):
            args.rand_seed = random.randint(1, 100000)
            best_arch_str, val_acc, test_acc, total_time_cost = main(args, base_path)
            val_accs.append(val_acc)
            test_accs.append(test_acc)
            total_time_costs.append(total_time_cost)

            with open(path_res, 'a') as f:
                f.write(
                    result_line.format(
                        args.rand_seed,
                        i + 1,
                        val_acc,
                        test_acc,
                        total_time_cost,
                        best_arch_str
                    )
                )

        with open(path_res, 'a') as f:
            f.write("All best arch's mean val_acc: {}, std: {}\n".format(np.mean(val_accs), np.std(val_accs)))
            f.write("All best arch's mean test_acc: {}, std: {}\n".format(np.mean(test_accs), np.std(test_accs)))
            f.write("All best arch's mean total_time_cost: {}, std: {}\n".format(np.mean(total_time_costs), np.std(total_time_costs)))

    else:
        # Fixed seed: a single run. main() requires base_path as its second
        # argument; it was previously omitted here, raising TypeError.
        best_arch_str, val_acc, test_acc, total_time_cost = main(args, base_path)
        with open(path_res, 'a') as f:
            f.write(
                result_line.format(
                    args.rand_seed,
                    1,
                    val_acc,
                    test_acc,
                    total_time_cost,
                    best_arch_str
                )
            )



